# -*- coding: utf-8 -*-
"""
Created on Sat Mar 19 20:03:42 2022

"""


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
from torch.autograd import Variable
import time
import pandas as pd

import networkx as nx
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
from datetime import datetime
from torch.nn.utils import clip_grad_norm_
import copy
import random
import pickle
import os

from models import PhyCRNet,MaskedMSELoss,MaskedL1Loss,weightMSELoss
from data_utils import PrepareConcDataset
from utils import plot_conc,read_logs


#%%load args and model
# For every run returned by read_logs: reload the pickled training args,
# rebuild the datasets and the PhyCRNet model from them, restore the saved
# weights, and call plot_conc with that run's log directory as save path.

# Alternative experiment folder; currently only used by the commented-out
# read_logs call below.
log_path='exp/exp8/log_conc/'
# Keys of the dict returned by read_logs (the lists are indexed in parallel
# by run in the loop below):
# dirs={
#       "log_dirs":log_dirs,
#       "model_dirs":model_dirs,
#       "args_dirs":args_dirs,
#       "events_dirs":events_dirs,
#       "orders":orders
#       }
dirs=read_logs('exp/exp_paper/')
#dirs=read_logs(log_path)

for i in range(len(dirs['log_dirs'])):

    # Training hyper-parameters saved with the run.
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # script at trusted experiment folders.
    with open(dirs['args_dirs'][i],'rb') as f:
        args=pickle.load(f)

    # Choose the device once and fall back to CPU, so the model placement and
    # the checkpoint restore below also work on machines without a GPU.
    # (Previously `device` was only assigned when CUDA was present and was
    # later overwritten unconditionally with torch.device('cuda').)
    args['cuda'] = torch.cuda.is_available()
    device = torch.device('cuda' if args['cuda'] else 'cpu')
    args['train_ratio']=0.1

    # NOTE(review): every name assigned by this first dataset/model build is
    # reassigned by the second build below before it is used, and the two
    # PrepareConcDataset / PhyCRNet calls pass different argument lists. They
    # look like they target different versions of those signatures -- confirm
    # which one is current and delete the other.
    train_input, train_label1, train_label1_coe, train_label2, train_label2_mask,\
    val_input, val_label1, val_label1_coe, val_label2, val_label2_mask,\
    mmin_list ,dev_list, conc_name, edge_list ,time_scaler=PrepareConcDataset(args['time_window'],args['res'],
                                                              1,args['buffer'],flag='test')

    #reconfig model
    model = PhyCRNet(args['input_dim'],
                     args['hidden_dim'],
                     args['output_dim'],
                     args['input_kernel_size'],
                     args['input_stride'],
                     args['padding_mode'],
                     args['num_layers'],
                     args['dropout'],
                     args['warmup_updates'],
                     args['tot_updates'],
                     args['peak_lr'],
                     args['end_lr'],
                     1,
                     args['weight_decay'],)
                     #args['optim'])

    # Second build: these are the datasets and model actually used below.
    train_input,train_label1,train_label1_coe,train_label2,train_label2_mask,\
    val_input,val_label1,val_label1_coe,val_label2,val_label2_mask,\
    test_input,test_label1,test_label1_coe,test_label2,test_label2_mask,\
    mmin_list,dev_list,conc_name,edge_list,time_scaler=PrepareConcDataset(args['time_window'],args['res'],
                                                                          args['pred_day'],args['buffer'],
                                                                          args['gap'],args['test_day'],
                                                                          args['val_ratio'])

    #reconfig model
    model = PhyCRNet(args['input_dim'],
                     args['hidden_dim'],
                     args['output_dim'],
                     args['input_kernel_size'],
                     args['input_stride'],
                     args['padding_mode'],
                     args['num_layers'],
                     args['dropout'],
                     args['warmup_updates'],
                     args['tot_updates'],
                     args['peak_lr'],
                     args['end_lr'],
                     args['weight_decay'])

    model=model.to(device=device)

    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device was selected above (including CPU-only hosts).
    model_path=dirs['model_dirs'][i]
    model.load_state_dict(torch.load(model_path, map_location=device))
    save_path=dirs['log_dirs'][i]
    plot_conc(args,model,train_input,train_label1,train_label1_coe,train_label2,train_label2_mask,
              val_input,val_label1,val_label1_coe,val_label2,val_label2_mask,conc_name,
              save_path,device,dev_list,mmin_list,edge_list)
    # NOTE(review): this call's argument list does not match the plot_conc
    # call above; one of the two looks stale -- confirm against utils.plot_conc.
    # device.type is 'cuda' on GPU machines (same as the old literal) and
    # 'cpu' otherwise.
    plot_conc(args,model,'222',device.type,'conc')